package com.gitee.wenbo0;

import cn.hutool.cache.CacheUtil;
import cn.hutool.cache.impl.TimedCache;
import com.baomidou.mybatisplus.core.toolkit.StringUtils;
import com.gitee.wenbo0.annotation.DictText;
import com.zaxxer.hikari.HikariDataSource;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.executor.resultset.DefaultResultSetHandler;
import org.apache.ibatis.executor.resultset.ResultSetHandler;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Signature;

import javax.sql.DataSource;
import java.lang.reflect.Field;
import java.sql.*;
import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

/**
 * @author wenbo
 * @since 2022/7/1 15:31
 */
@Intercepts({@Signature(type = ResultSetHandler.class, method = "handleResultSets", args = {Statement.class})})
@Slf4j
public class DictTextResultSetHandlerPlugin implements Interceptor {
    private static final long DEFAULT_CACHE_TIME = 1000 * 60 * 5;
    private static final String DICT_SQL = "select %s from %s where %s = '%s' ";
    private static final String OTHER_WHERE = "and %s = '%s' ";
    private static final String COMMASEPARATE_SQL = "select group_concat(%s) from %s where %s in ('%s') ";
    protected static final String CACHE_SQL = "select * from %s";
    protected List<String> cacheTableNameList;
    protected TimedCache<String, List<Map<String, String>>> timedCache;

    public DictTextResultSetHandlerPlugin() {
    }

    public DictTextResultSetHandlerPlugin(String... tableNames) {
        this(DEFAULT_CACHE_TIME, tableNames);
    }

    /**
     * @param cacheTime  缓存时间 单位ms
     * @param tableNames 需要缓存的表名
     */
    public DictTextResultSetHandlerPlugin(long cacheTime, String... tableNames) {
        this.cacheTableNameList = Arrays.stream(tableNames).collect(Collectors.toList());
        timedCache = CacheUtil.newTimedCache(cacheTime);
    }

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        Object proceed = invocation.proceed();
        Object drsh = invocation.getTarget();
        if (proceed != null && drsh instanceof DefaultResultSetHandler) {
            if (!(proceed instanceof List))
                return proceed;
            DefaultResultSetHandler defaultResultSetHandler = (DefaultResultSetHandler) drsh;
            MappedStatement ms = (MappedStatement) getFieldValueByObject(defaultResultSetHandler, "mappedStatement");
            if (ms == null || ms.getResultMaps().size() != 1)
                return proceed;
            Class<?> type = ms.getResultMaps().get(0).getType();
            if (SpringContextHolder.getApplicationContext() == null) return proceed;
            DataSource dataSource = (DataSource) SpringContextHolder.getBean("dataSource");
//            try (Connection connection = DriverManager.getConnection(dataSource.getJdbcUrl(), dataSource.getUsername(), dataSource.getPassword())) {
            try (Connection connection = dataSource.getConnection()) {
                for (Object obj : (List) proceed) {
                    if (!type.isInstance(obj)) {
                        continue;
                    }
                    Field[] fields = type.getDeclaredFields();
                    for (Field field : fields) {
                        field.setAccessible(true);
                        Object value = field.get(obj);
                        if (value == null || (field.getAnnotation(DictText.class) != null && handleDictField(connection, obj, fields, field, value)))
                            continue;
                        field.setAccessible(false);
                    }
                }
            }
        }
        return proceed;
    }

    public static boolean listContainsIgnoreCase(List<String> list, String searchString) {
        for (String item : list) {
            if (item.equalsIgnoreCase(searchString)) {
                return true;
            }
        }
        return false;
    }

    private boolean handleDictField(Connection conn, Object obj, Field[] fields, Field field, Object value) throws Exception {
        DictText dictText = field.getAnnotation(DictText.class);
        String tableName = dictText.tableName();
        String keyColumn = dictText.keyColumn();
        String textColumn = dictText.textColumn();
        String target = dictText.target();
        if (StringUtils.isBlank(tableName) || StringUtils.isBlank(target))
            return true;
        String otherColumn = dictText.otherColumn();
        String otherValue = parseExpression(dictText.otherValue(), fields, obj);
        boolean commaSeparate = dictText.commaSeparate();
        String[] textColumnArr = textColumn.split(",");
        String[] targetArr = target.split(",");
        if (textColumnArr.length != targetArr.length) {
            log.warn("字段{}的textColumn和target数量不同", field.getName());
            return true;
        }

        if (listContainsIgnoreCase(this.cacheTableNameList, tableName)) {
            List<Map<String, String>> cacheList = timedCache.get(tableName.toLowerCase(), false);
            if (cacheList == null) {
                try (ResultSet rs = conn.prepareStatement(String.format(CACHE_SQL, tableName)).executeQuery()) {
                    List<Map<String, String>> list = rsToList(rs);
                    timedCache.put(tableName.toLowerCase(), list);
                    for (int i = 0, splitLength = textColumnArr.length; i < splitLength; i++) {
                        String textColumn1 = textColumnArr[i];
                        String target1 = targetArr[i];
                        String foreignText = getCache(value, keyColumn, textColumn1, otherColumn, otherValue, commaSeparate, list);
                        setFieldValue(obj, fields, target1, foreignText);
                    }
                }
            } else {
                for (int i = 0, splitLength = textColumnArr.length; i < splitLength; i++) {
                    String textColumn1 = textColumnArr[i];
                    String target1 = targetArr[i];
                    String foreignText = getCache(value, keyColumn, textColumn1, otherColumn, otherValue, commaSeparate, cacheList);
                    setFieldValue(obj, fields, target1, foreignText);
                }
            }
        } else {
            for (int i = 0, splitLength = textColumnArr.length; i < splitLength; i++) {
                String textColumn1 = textColumnArr[i];
                String target1 = targetArr[i];
                String sql, foreignText;
                //不需要其他字段条件时
                if (StringUtils.isBlank(otherValue)) {
                    //有逗号分割时
                    if (commaSeparate) {
                        String replace = value.toString().replace(",", "','");
                        sql = String.format(COMMASEPARATE_SQL, textColumn1, tableName, keyColumn, replace);
                    } else {
                        sql = String.format(DICT_SQL, textColumn1, tableName, keyColumn, value);
                    }
                } else {
                    //有逗号分割时
                    if (commaSeparate) {
                        String replace = value.toString().replace(",", "','");
                        sql = String.format(COMMASEPARATE_SQL + OTHER_WHERE, textColumn1,
                                tableName, keyColumn, replace, otherColumn, otherValue);
                    } else {
                        sql = String.format(DICT_SQL + OTHER_WHERE, textColumn1, tableName,
                                keyColumn, value, otherColumn, otherValue);
                    }
                }
                try (ResultSet rs = conn.prepareStatement(sql).executeQuery()) {
                    if (rs.next()) {
                        foreignText = rs.getString(1);
                    } else {
                        foreignText = null;
                    }
                }
                setFieldValue(obj, fields, target1, foreignText);
            }
        }
        return false;
    }

    private static void setFieldValue(Object obj, Field[] fields, String target, String foreignText) throws IllegalAccessException {
        List<Field> collect = Arrays.stream(fields).filter(field1 -> field1.getName().equals(target)).collect(Collectors.toList());
        if (collect.size() > 0) {
            Field targetField = collect.get(0);
            targetField.setAccessible(true);
            targetField.set(obj, foreignText);
            targetField.setAccessible(false);
        }
    }

    @SneakyThrows
    private String parseExpression(String otherValue, Field[] fields, Object obj) {
        String regex = "\\$\\{(.*?)}";
        Pattern pattern = Pattern.compile(regex);
        Matcher matcher;
        while ((matcher = pattern.matcher(otherValue)).find()) {
            String itemName = matcher.group().substring(2, matcher.group().length() - 1);
            List<Field> collect = Arrays.stream(fields).filter(field -> field.getName().equals(itemName)).collect(Collectors.toList());
            if (collect.size() > 0) {
                Field field = collect.get(0);
                field.setAccessible(true);
                String value = field.get(obj).toString();
                field.setAccessible(false);
                otherValue = otherValue.replace(matcher.group(), value);
            }
        }
        return otherValue;
    }

    private static String getCache(Object value, String keyColumn, String textColumn, String otherColumn,
                                   String otherValue, boolean commaSeparate, List<Map<String, String>> cacheList) {
        String foreignText = null;
        if (StringUtils.isBlank(otherValue)) {
            if (commaSeparate) {
                StringBuilder stringBuilder = new StringBuilder();
                for (String s : value.toString().split(",")) {
                    for (Map<String, String> cacheObj : cacheList) {
                        Object keyObj = cacheObj.get(keyColumn);
                        Object textObj = cacheObj.get(textColumn);
                        if (keyObj == null) {
                            throw new RuntimeException("请检查配置的keyColumn是否正确(数据库是否区分大小写)");
                        }
                        if (s.equals(String.valueOf(keyObj))) {
                            if (textObj == null) {
                                stringBuilder.append("null").append(",");
                            } else {
                                stringBuilder.append(textObj).append(",");
                            }
                        }
                    }
                }
                foreignText = stringBuilder.length() > 0 ? stringBuilder.substring(0, stringBuilder.length() - 1) : null;
            } else {
                for (Map<String, String> cacheObj : cacheList) {
                    Object keyObj = cacheObj.get(keyColumn);
                    Object textObj = cacheObj.get(textColumn);
                    if (keyObj == null) {
                        throw new RuntimeException("请检查配置的keyColumn是否正确(数据库是否区分大小写)");
                    }
                    if (String.valueOf(value).equals(String.valueOf(keyObj))) {
                        if (textObj != null) {
                            foreignText = textObj.toString();
                        }
                        break;
                    }
                }
            }
        } else {
            if (commaSeparate) {
                StringBuilder stringBuilder = new StringBuilder();
                for (String s : value.toString().split(",")) {
                    for (Map<String, String> cacheObj : cacheList) {
                        Object keyObj = cacheObj.get(keyColumn);
                        Object textObj = cacheObj.get(textColumn);
                        Object otherObj = cacheObj.get(otherColumn);
                        if (keyObj == null || otherObj == null) {
                            throw new RuntimeException("请检查配置的keyColumn otherColumn是否正确(数据库是否区分大小写)");
                        }
                        if (s.equals(String.valueOf(keyObj)) && otherValue.equals(String.valueOf(otherObj))) {
                            if (textObj == null) {
                                stringBuilder.append("null").append(",");
                            } else {
                                stringBuilder.append(textObj).append(",");
                            }
                        }
                    }
                }
                foreignText = stringBuilder.length() > 0 ? stringBuilder.substring(0, stringBuilder.length() - 1) : null;
            } else {
                for (Map<String, String> cacheObj : cacheList) {
                    Object keyObj = cacheObj.get(keyColumn);
                    Object textObj = cacheObj.get(textColumn);
                    Object otherObj = cacheObj.get(otherColumn);
                    if (keyObj == null || otherObj == null) {
                        throw new RuntimeException("请检查配置的keyColumn otherColumn是否正确(数据库是否区分大小写)");
                    }
                    if (String.valueOf(value).equals(String.valueOf(keyObj)) && otherValue.equals(String.valueOf(otherObj))) {
                        if (textObj != null) {
                            foreignText = textObj.toString();
                        }
                        break;
                    }
                }
            }
        }
        return foreignText;
    }

    protected static List<Map<String, String>> rsToList(ResultSet rs) {
        List<Map<String, String>> list = new ArrayList<>();
        try {
            ResultSetMetaData md = rs.getMetaData();
            int columnCount = md.getColumnCount();
            while (rs.next()) {
                Map<String, String> rowData = new HashMap<>();
                for (int i = 1; i <= columnCount; i++) {
                    rowData.put(md.getColumnName(i), rs.getString(i));
                }
                list.add(rowData);
            }
        } catch (SQLException e) {
            e.printStackTrace();
        }
        return list;
    }

    private static Object getFieldValueByObject(Object object, String targetFieldName) {
        // 获取该对象的Class
        Class objClass = object.getClass();
        // 初始化返回值
        Object result;
        // 获取所有的属性数组
        Field[] fields = objClass.getDeclaredFields();
        for (Field field : fields) {
            try {
                if (field.getName().equals(targetFieldName)) {
                    field.setAccessible(true);
                    result = field.get(object);
                    return result; // 通过反射拿到该属性在此对象中的值(也可能是个对象)
                }
            } catch (SecurityException e) {
                // 安全性异常
                e.printStackTrace();
            } catch (IllegalArgumentException e) {
                // 非法参数
                e.printStackTrace();
            } catch (IllegalAccessException e) {
                // 无访问权限
                e.printStackTrace();
            }
        }
        return null;
    }
}
